import numpy as np
import pandas as pd
import os
import sys
import pickle
from joblib import dump, load
import yaml
import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse
sys.path.insert(1,'/REDACTED/fairness/code/rf/scripts/rf')
from gpd_test import *

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import cm
import matplotlib.font_manager as fm
import matplotlib
import sklearn.metrics as metrics
import seaborn as sns
import pdb
sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
from trajectoryPlotters import *

# Default directories used further down in this script: the first is prefixed
# to the "minus_6" pickle file names, the second is the default output
# directory for saved figures.
defaultaxpayer_id = '/REDACTED/data/modeled_refactor_temp/minus_6/'
defaultout = '/REDACTED/'

# Apply the project style sheet, then register the bundled TTF under the
# family name 'latex' and make that family the default for all text.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/utilities/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name
plt.rcParams['mathtext.default'] = 'regular'

# Named palette, one (color name, hex code) pair per entry.  The two flat
# lists and the name -> RGBA dict below are all derived from this table so
# the names and hex codes cannot drift out of step.
_named_palette = [
    ('background blue', '#f0f4f5'),
    ('gray blue', '#d0d8da'),
    ('dark gray-blue', '#aabeC6'),
    ('accent blue', '#009abb'),
    ('link blue', '#007c92'),
    ('dark blue', '#09425a'),
    ('light green', '#c7d1c6'),
    ('bright green', '#80982a'),
    ('dark green', '#556222'),
    ('orange', '#b96d12'),
    ('purple', '#53284f'),
    ('maroon', '#5e3032'),
    ('cardinal red', '#8c1515'),
    ('line gray', '#dddddd'),
    ('light gray', '#cccccc'),
    ('gray', '#999999'),
    ('med gray', '#666666'),
    ('text gray', '#333333'),
    ('black', '#000000'),
]

stanford_palette = [hex_code for _name, hex_code in _named_palette]
color_codes = [name for name, _hex_code in _named_palette]

cdict = {name: mcolors.to_rgba(hex_code) for name, hex_code in _named_palette}
# Registering these makes the names usable anywhere matplotlib accepts a color
# string.  Note that entries such as 'orange', 'purple', 'maroon', 'gray' and
# 'black' replace matplotlib's built-in colors of the same name.
mcolors.get_named_colors_mapping().update(cdict)

###############################################
#### GET PLOT DICTIONARIES
###############################################

def _read_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    NOTE(review): pickle.load can execute arbitrary code while loading, so
    these files must come from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


### Refundable Credit - REGRESSOR MINUS 6

unres_eitc_reg_dep_database_ref_cred_minus_6 = _read_pickle(
    defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

res_eitc_reg_dep_database_ref_cred_minus_6 = _read_pickle(
    defaultaxpayer_id + 'eitc_res_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

### Refundable Credit - REGRESSOR

unres_eitc_reg_dep_database_ref_cred = _read_pickle(
    '/REDACTED/data/modeled_refactor_temp/eitc_unres_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

res_eitc_reg_dep_database_ref_cred = _read_pickle(
    '/REDACTED/data/modeled_refactor_temp/eitc_res_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

### Overclaiming Oracle

unres_eitc_oracle_dep_database_ref_cred = _read_pickle(
    '/REDACTED/data/modeled_refactor_temp/eitc_unres_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')

res_eitc_oracle_dep_database_ref_cred = _read_pickle(
    '/REDACTED/data/modeled_refactor_temp/eitc_res_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle')


# Keys for data_dict and the loaded objects they refer to, kept in the same
# order so that entry k of one list describes entry k of the other.
string_list = [
    'res_eitc_reg_dep_database_ref_cred_minus_6',
    'unres_eitc_reg_dep_database_ref_cred_minus_6',
    'res_eitc_reg_dep_database_ref_cred',
    'unres_eitc_reg_dep_database_ref_cred',
    'res_eitc_oracle_dep_database_ref_cred',
    'unres_eitc_oracle_dep_database_ref_cred',
]

file_list = [
    res_eitc_reg_dep_database_ref_cred_minus_6,
    unres_eitc_reg_dep_database_ref_cred_minus_6,
    res_eitc_reg_dep_database_ref_cred,
    unres_eitc_reg_dep_database_ref_cred,
    res_eitc_oracle_dep_database_ref_cred,
    unres_eitc_oracle_dep_database_ref_cred,
]

data_dict = {}

print(len(file_list))
print(len(string_list))

# getAllQuantities is not defined in this file; presumably it comes from one
# of the star imports at the top.  Each loaded object is passed through it
# with bootstrap=True and the result stored under the matching key.
for i, (dict_key, loaded_obj) in enumerate(zip(string_list, file_list)):
    print(i)
    data_dict[dict_key] = getAllQuantities(loaded_obj, bootstrap=True)

## add makeTrajectoryPlot function so we can fiddle with status quo label
def makeTrajectoryPlot(modelnames,
                        quantities,
                        status_quo_point,
                        colors,
                        linestyles,
                        trajname,
                        plot_sq=True,
                        dict_x_varb='bootstrap_revenue',
                        dict_y_varb='full_chen_fair',
                        dict_se_varb='full_chen_fair',
                        sq_y_varb='eitc_chen_fair',
                        x_lab='Detected Underreporting ($ Millions)',
                        y_lab='Disparity (percentage points)',
                        activity_code=None,
                        out=defaultout,
                        eitc=False,
                        offset=False,
                        annotatePct=True,
                        annotateModelName=False,
                        annotationcoords=[[0,0]],
                        ref_pt_offset=[False],
                        sq_coords_eitc=[500,-0.001],
                        sq_coords_pop=[1000,-0.0005],
                        plot_error = True,
                        title_lab = ''):
    """Draw one trajectory line per model on a single axes and save it as PNG.

    For model ``i`` the x values are ``quantities[i][dict_x_varb + '_mean']``,
    the y values are ``quantities[i][dict_y_varb + '_mean']`` and the shaded
    band is y +/- 1.96 * ``quantities[i][dict_se_varb + '_std']``.  Y tick
    labels are shown divided by 0.01, i.e. decimals are displayed as
    percentages.

    Keys read from ``status_quo_point`` all share a prefix: ``str(activity_code)``
    when ``activity_code`` is given, otherwise ``'eitc'`` if ``eitc`` is True
    and ``'overall'`` if it is False.  The keys are ``<prefix>_rev``,
    ``<prefix>_<sq_y_varb>`` and ``<prefix>_budget``.

    Args:
        modelnames: Legend/annotation label for each trajectory.
        quantities: One mapping per model holding the ``*_mean`` / ``*_std``
            sequences described above.
        status_quo_point: Mapping with the status-quo values described above.
        colors: Line (and band) color for each model.
        linestyles: Line style for each model.
        trajname: File name stem; the figure is saved to ``out + trajname + '.png'``.
        plot_sq: Only the literal ``True`` draws the status-quo cross, its
            dotted guide lines and its text label.
        dict_x_varb, dict_y_varb, dict_se_varb: Key stems looked up in each
            entry of ``quantities``.
        sq_y_varb: Key stem for the status-quo y value.
        x_lab, y_lab, title_lab: Axis labels and axes title.
        activity_code: Optional prefix selecting activity-specific status-quo keys.
        out: Output directory prefix (concatenated as a plain string).
        eitc: Selects the ``'eitc'`` vs ``'overall'`` key prefix when
            ``activity_code`` is None.
        offset: If True, the 1%/2%/3% labels of odd-numbered models are
            lowered by 0.001 so neighbouring labels do not overlap.
        annotatePct: If truthy, mark and label the status-quo budget on every
            trajectory.
        annotateModelName: If truthy, write each model name at
            ``annotationcoords[i]`` (one ``[x, y]`` pair per model is required).
        ref_pt_offset: Per-model flags; True places the budget label below the
            marker instead of above.  Models beyond the end of this list are
            treated as False.
        sq_coords_eitc, sq_coords_pop: ``[dx, dy]`` shifts for the status-quo
            text label (EITC and non-EITC variants respectively).
        plot_error: If True, draw the 1.96 * std band.

    Returns:
        The matplotlib Figure that was drawn and saved.

    Raises:
        ValueError: If ``plot_sq`` is True but ``eitc`` is neither True nor
            False and no ``activity_code`` was given.
    """
    # Imported locally: no visible import in this file binds `ticker`, so the
    # formatter below would otherwise depend on a star import providing it.
    from matplotlib import ticker

    fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(10, 8))
    fig.suptitle('', fontsize=24)
    ax.set_xlabel(x_lab, fontsize=20)
    ax.set_ylabel(y_lab, fontsize=20)
    ax.set_title(str(title_lab), fontsize=20)
    ax.tick_params(labelsize=16)
    ax.axhline(y=0, color='gray')

    xs = [q[dict_x_varb + '_mean'] for q in quantities]
    ys = [q[dict_y_varb + '_mean'] for q in quantities]
    errs = [q[dict_se_varb + '_std'] for q in quantities]

    # Work out once which family of status_quo_point keys applies.  `mode`
    # picks the layout constants, `prefix` the dictionary keys; they are kept
    # separate because an activity code could itself be the string 'eitc'.
    if activity_code is not None:
        mode, prefix = 'activity', str(activity_code)
    elif eitc == True:
        mode, prefix = 'eitc', 'eitc'
    elif eitc == False:
        mode, prefix = 'overall', 'overall'
    else:
        mode, prefix = None, None

    draw_sq = plot_sq is True
    if draw_sq:
        if mode is None:
            raise ValueError('eitc must be True or False when activity_code is None')
        sq_x = status_quo_point[prefix + '_rev'] / 1000000
        sq_y = status_quo_point[prefix + '_' + sq_y_varb]
        ax.axhline(y=sq_y, color='red', linestyle='dotted')
        ax.axvline(x=sq_x, color='red', linestyle='dotted')
        ax.plot(sq_x, sq_y, 'kx', markersize=10)

    ## scale up y axis to percentages (instead of decimals)
    scale_y = 0.01
    ax.yaxis.set_major_formatter(
        ticker.FuncFormatter(lambda val, pos: '{0:g}'.format(val / scale_y)))

    if draw_sq:
        sq_pct = str(round(status_quo_point[prefix + '_budget'] * 100, 2))
        if mode == 'eitc':
            # The EITC label gets an extra fixed shift on top of sq_coords_eitc.
            ax.annotate('Status quo\n    ' + sq_pct + '%',
                        xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_eitc[0] + 500, sq_y + sq_coords_eitc[1] + 0.003),
                        size=16)
        else:
            ax.annotate('Status quo, ' + sq_pct + '% audit rate',
                        xy=(sq_x, sq_y),
                        xytext=(sq_x + sq_coords_pop[0], sq_y + sq_coords_pop[1]),
                        size=16)

    # (dx, dy when ref_pt_offset is set, dy otherwise) for the budget label.
    pct_label_offsets = {
        'eitc': (75, -0.0015, 0.00025),
        'overall': (100, -0.001, 0.00075),
        'activity': (-200, -0.001, 0.00075),
    }

    for i, model in enumerate(modelnames):
        x = np.array(xs[i])
        print(x)
        y = np.array(ys[i])
        print(y)
        ax.plot(x, y, label=model, color=colors[i], linestyle=linestyles[i])

        # np.float was only an alias for the builtin and no longer exists in
        # current NumPy, so cast with the builtin directly.
        yerr = np.array(errs[i]).astype(float)
        if plot_error == True:
            band = yerr * 1.96
            ax.fill_between(x, y - band, y + band, alpha=0.2, edgecolor='k', facecolor=colors[i])

        # Mark every 100th point (indices 99, 199, 299) as 1%, 2%, 3%.
        # Presumably index k corresponds to a budget of (k + 1) / 100 percent
        # -- TODO confirm against getAllQuantities.  Trajectories shorter than
        # 300 points simply get fewer markers instead of raising IndexError.
        if offset == True or offset == False:
            n_points = min(len(x), len(y), 300)
            for j in range(99, n_points, 100):
                if offset == True:
                    dist = (float(i) % 2) * 0.001
                    print(dist)
                    label_y = y[j] - dist
                else:
                    label_y = y[j]
                ax.annotate(str((j + 1) // 100) + '%', xy=(x[j] + 50, label_y), fontsize=16)
                ax.plot(x[j], y[j], 'ko', markersize=5)

        if annotatePct and mode is not None:
            budget = status_quo_point[prefix + '_budget']
            if mode == 'activity':
                idx = math.floor(round(budget * 10000 - 1, 2))
            else:
                # Round the scaled budget to the nearest integer before
                # subtracting 1: flooring the raw float product can come out
                # one index too low when it lands just under an integer.
                idx = int(round(round(budget, 4) * 10000)) - 1
            ax.plot(x[idx], y[idx], 'ko', markersize=10)

            dx, dy_shifted, dy_default = pct_label_offsets[mode]
            shifted = i < len(ref_pt_offset) and ref_pt_offset[i] == True
            dy = dy_shifted if shifted else dy_default
            ax.annotate(str(round(budget * 100, 2)) + '%',
                        xy=(x[idx], y[idx]),
                        xytext=(x[idx] + dx, y[idx] + dy),
                        size=16)

        if annotateModelName:
            ax.annotate(modelnames[i], (annotationcoords[i][0], annotationcoords[i][1]), size=20)

    fig.savefig(out + trajname + '.png')
    return fig

###############################################
#### Trajectory Plots
###############################################

###############################################
#### Fig 8: Detected Underreporting and Disparity by Algorithm
###############################################

### Unrestricted refundable-credit trajectories: prediction, prediction minus 6, oracle.
# One row per trajectory:
#   (model label, data_dict key, line color, [x, y] position of the model-name label)
fig8_series = [
    ('Refundable Credit Prediction', 'unres_eitc_reg_dep_database_ref_cred', 'dark blue', [100, 0.03]),
    ('Refundable Credit Prediction Minus 6', 'unres_eitc_reg_dep_database_ref_cred_minus_6', 'purple', [3300, 0.025]),
    ('Refundable Credit Oracle', 'unres_eitc_oracle_dep_database_ref_cred', 'orange', [4000, 0.0125]),
]

# status_quo_point is not defined in this file; presumably it is provided by
# one of the star imports at the top -- verify before running standalone.
# With plot_sq=False the status-quo cross and its text label are not drawn, so
# sq_y_varb and the sq_coords_* values are passed but have no visible effect.
makeTrajectoryPlot(
    # what to draw
    modelnames=[label for label, _key, _color, _pos in fig8_series],
    quantities=[data_dict[key] for _label, key, _color, _pos in fig8_series],
    status_quo_point=status_quo_point,
    colors=[color for _label, _key, color, _pos in fig8_series],
    linestyles=['solid'] * len(fig8_series),
    # which series to read from each quantities entry
    dict_x_varb='bootstrap_total_ref_adj',
    dict_y_varb='bootstrap_chen_fair',
    dict_se_varb='bootstrap_chen_fair',
    sq_y_varb='chen_fair',
    # labels
    x_lab='Detected Overclaiming ($ Millions)',
    y_lab='Disparity (probabilistic, percentage points)',
    # status-quo handling
    plot_sq=False,
    activity_code=None,
    eitc=True,
    sq_coords_eitc=[-2000, -0.002],
    sq_coords_pop=[-1000, 0.002],
    # annotations
    offset=True,
    annotatePct=True,
    annotateModelName=True,
    ref_pt_offset=[True] * len(fig8_series),
    annotationcoords=[pos for _label, _key, _color, pos in fig8_series],
    plot_error=True,
    # output
    trajname='trajplot_ref_regs_new_x_bootstrap',
    out=defaultout,
)

plt.close()